In the previous articles "How Should Learning Rate Scale with Batch Size?" and "How Does Adam's Epsilon Affect Learning Rate Scaling Laws?", we discussed the scaling relationship between learning rate and batch size from a theoretical perspective. The classical part of this analysis, based on a second-order expansion of the loss, was proposed by OpenAI. However, for non-SGD optimizers the calculations involved quickly become quite complex, which can feel overwhelming.

In the following series of articles, I will reorganize and reconsider the relevant details from the aforementioned articles, attempting to simplify some of the derivation steps, providing a more general and elegant derivation path, and exploring the possibility of extending this to the Muon optimizer.

Method Overview#

First, let's review the previous analytical approach. In "How Should Learning Rate Scale with Batch Size?", we introduced several analytical perspectives on the relationship between learning rate and batch size. The main focus was on the second-order approximation analysis proposed by OpenAI in "An Empirical Model of Large-Batch Training", and this article follows the same line of thought.

Next, we need to introduce some notation. Let the loss function be $\mathcal{L}(\boldsymbol{w})$, where $\boldsymbol{w}\in\mathbb{R}^N$ is the parameter vector, and $\boldsymbol{g}$ is its gradient. Note that the ideal loss function is computed as an expectation over the entire training dataset, but in practice we can only sample a batch to compute it, which introduces randomness into the gradient. We denote the gradient for a single sample as $\tilde{\boldsymbol{g}}$, whose mean is $\boldsymbol{g}$, and its covariance matrix as $\boldsymbol{\Sigma}$. When the batch size is $B$, the gradient is denoted as $\tilde{\boldsymbol{g}}_B$, whose mean remains $\boldsymbol{g}$, but the covariance matrix becomes $\boldsymbol{\Sigma}/B$.
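
As a quick numerical check of this notation, here is a minimal sketch with made-up values of $\boldsymbol{g}$ and $\boldsymbol{\Sigma}$ (the per-sample gradients are drawn Gaussian purely for convenience; the $\boldsymbol{\Sigma}/B$ scaling only relies on the per-sample gradients being i.i.d.):

```python
import numpy as np

rng = np.random.default_rng(0)
N, B, trials = 3, 32, 50_000

g = np.array([0.5, -1.0, 2.0])   # mean of the per-sample gradients
A = rng.normal(size=(N, N))
Sigma = A @ A.T                  # an arbitrary positive-definite covariance

# Each batch gradient is the average of B i.i.d. per-sample gradients
# (Gaussian here purely for convenience; only independence matters for the scaling).
batch_grads = rng.multivariate_normal(g, Sigma, size=(trials, B)).mean(axis=1)

print(np.allclose(batch_grads.mean(axis=0), g, atol=0.02))         # mean stays g
print(np.allclose(np.cov(batch_grads.T), Sigma / B, atol=0.02))    # covariance shrinks to Sigma/B
```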

Furthermore, let the current learning rate be $\eta$, and the update vector be $\tilde{\boldsymbol{\varphi}}_B$. Then the loss function after update will be:

(1) \[ \begin{aligned} \mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{\varphi}}_B) \approx&\, \mathcal{L}(\boldsymbol{w}) - \eta \tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2\tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{\varphi}}_B \\[5pt] =&\, \mathcal{L}(\boldsymbol{w}) - \eta \tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2\newcommand{\tr}{\mathop{\text{tr}}}\tr(\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{H}) \end{aligned} \]

On the right side, we have Taylor expanded to second order, where $\boldsymbol{H}$ is the Hessian matrix, $\tr$ is the matrix trace, and the second equality uses the identity $\tr(\boldsymbol{A}\boldsymbol{B})=\tr(\boldsymbol{B}\boldsymbol{A})$. To obtain a deterministic result, we take expectations on both sides:

(2) \[ \mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{\varphi}}_B)] \approx \mathcal{L}(\boldsymbol{w}) - \eta\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \tr(\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]\boldsymbol{H}) \]

We view the right-hand side as a quadratic function in $\eta$, and assume the quadratic coefficient is positive (a stronger assumption is that the $\boldsymbol{H}$ matrix is positive definite). Then we can obtain the minimum point:

(3) \[ \eta^* \approx \frac{\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]^{\top}\boldsymbol{g}}{\tr(\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]\boldsymbol{H})} \]

This is the learning rate that, on average, decreases the loss fastest, i.e. the theoretically optimal learning rate. Our task is then to compute $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$ and $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$ for a specific $\tilde{\boldsymbol{\varphi}}_B$, and then read off from the above equation how $\eta^*$ depends on the batch size $B$.
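
To make equation (3) concrete, here is a minimal Monte Carlo sketch with a made-up $\boldsymbol{g}$, $\boldsymbol{H}$ and $\boldsymbol{\Sigma}$ (the Gaussian sampling of $\tilde{\boldsymbol{g}}_B$ is only there to generate synthetic batch gradients; the formula itself does not require it):

```python
import numpy as np

rng = np.random.default_rng(1)
N = 4

# Made-up problem data: a gradient g, a diagonal positive-definite Hessian H,
# and a diagonal gradient-noise covariance Sigma.
g = rng.normal(size=N)
H = np.diag(rng.uniform(0.5, 2.0, size=N))
Sigma = np.diag(rng.uniform(0.1, 1.0, size=N))

def optimal_lr(update_rule, B, n_samples=200_000):
    """Monte Carlo estimate of eta* in equation (3), sampling g_B ~ N(g, Sigma/B)."""
    g_B = rng.multivariate_normal(g, Sigma / B, size=n_samples)
    phi = update_rule(g_B)                      # update vectors, shape (n_samples, N)
    E_phi = phi.mean(axis=0)                    # estimate of E[phi_B]
    E_phi_phiT = phi.T @ phi / n_samples        # estimate of E[phi_B phi_B^T]
    return (E_phi @ g) / np.trace(E_phi_phiT @ H)

# SGD (phi = g_B) is the simplest rule; other rules (e.g. an elementwise sign) plug in the same way.
for B in (1, 8, 64, 512):
    print(B, optimal_lr(lambda g_batch: g_batch, B))
```

The reason for keeping `update_rule` pluggable is that the recipe itself never changes across optimizers; only the two expectations do, which is exactly where the difficulty discussed later comes from.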

Warm-up Exercise#

As our first example, we naturally consider the simplest case: SGD. Here, $\tilde{\boldsymbol{\varphi}}_B=\tilde{\boldsymbol{g}}_B$. We can easily obtain $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]=\boldsymbol{g}$ and $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]=\boldsymbol{g}\boldsymbol{g}^{\top} + \boldsymbol{\Sigma}/B$. Thus:

(4) \[ \eta^* \approx \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\tr((\boldsymbol{g}\boldsymbol{g}^{\top} + \boldsymbol{\Sigma}/B)\boldsymbol{H})} = \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \tr(\boldsymbol{\Sigma}\boldsymbol{H})/B} = \frac{\eta_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} \]

where:

(5) \[ \eta_{\max} = \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}},\qquad\mathcal{B}_{\text{noise}} = \frac{\tr(\boldsymbol{\Sigma}\boldsymbol{H})}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}} \]
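
In the same toy spirit as the sketch above, we can check that the closed form (4)-(5) agrees with a direct Monte Carlo estimate of (3) for SGD (again with made-up $\boldsymbol{g}$, $\boldsymbol{H}$, $\boldsymbol{\Sigma}$):

```python
import numpy as np

rng = np.random.default_rng(2)
N = 4
g = rng.normal(size=N)
H = np.diag(rng.uniform(0.5, 2.0, size=N))      # made-up diagonal positive-definite Hessian
Sigma = np.diag(rng.uniform(0.1, 1.0, size=N))  # made-up diagonal gradient-noise covariance

eta_max = (g @ g) / (g @ H @ g)                 # equation (5)
B_noise = np.trace(Sigma @ H) / (g @ H @ g)

for B in (1, 8, 64, 512):
    closed_form = eta_max / (1 + B_noise / B)                    # equation (4)
    g_B = rng.multivariate_normal(g, Sigma / B, size=200_000)    # direct Monte Carlo estimate of (3)
    monte_carlo = (g_B.mean(axis=0) @ g) / np.trace((g_B.T @ g_B / len(g_B)) @ H)
    print(f"B={B:4d}  closed form {closed_form:.4f}  Monte Carlo {monte_carlo:.4f}")
```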

We can interpret result (4) in several ways. First, it is a monotonically increasing but bounded function of $B$, with upper bound $\eta_{\max}$. This indicates that the learning rate cannot grow without bound, which matches intuition better than a simple linear or square-root scaling law. When $B \ll \mathcal{B}_{\text{noise}}$, we have:

(6) \[ \eta^* \approx \frac{\eta_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} \approx \frac{\eta_{\max}}{\mathcal{B}_{\text{noise}}/B} = \eta_{\max} B / \mathcal{B}_{\text{noise}} \]

This shows that when batch size is relatively small, SGD's learning rate indeed scales linearly with batch size, while also suggesting that $\mathcal{B}_{\text{noise}}$ is a key statistical quantity. However, the definition of $\mathcal{B}_{\text{noise}}$ depends on the Hessian matrix $\boldsymbol{H}$, which is nearly impossible to compute precisely in LLMs. Therefore, in practice, we typically assume it is the identity matrix (or a multiple thereof), yielding a simplified form:

(7) \[ \mathcal{B}_{\text{simple}} = \frac{\tr(\boldsymbol{\Sigma})}{\boldsymbol{g}^{\top}\boldsymbol{g}} \]

This result takes the form of noise intensity ($\tr(\boldsymbol{\Sigma})$) divided by signal intensity ($\boldsymbol{g}^{\top}\boldsymbol{g}$)—essentially the inverse of the signal-to-noise ratio. It indicates that with a smaller signal-to-noise ratio, a larger batch size is required to achieve the same $\eta_{\max}$, which aligns with our intuitive understanding. Note that $\tr(\boldsymbol{\Sigma})$ depends only on the diagonal elements of $\boldsymbol{\Sigma}$, meaning we only need to independently estimate the mean and variance for each parameter—a feasible approach in practice.
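
As an illustration of how $\mathcal{B}_{\text{simple}}$ might be estimated, here is a minimal sketch assuming we can collect individual per-sample gradients; the bias correction for $\boldsymbol{g}^{\top}\boldsymbol{g}$ is just one reasonable bookkeeping choice, not a prescription from the article:

```python
import numpy as np

def estimate_B_simple(per_sample_grads):
    """per_sample_grads: array of shape (n_samples, n_params), one gradient per sample."""
    n = per_sample_grads.shape[0]
    g_hat = per_sample_grads.mean(axis=0)                      # estimate of g
    tr_Sigma = per_sample_grads.var(axis=0, ddof=1).sum()      # sum of per-parameter variances ~ tr(Sigma)
    # ||g_hat||^2 overestimates g^T g by tr(Sigma)/n on average, so subtract that bias.
    g_sq = g_hat @ g_hat - tr_Sigma / n
    return tr_Sigma / g_sq

# Synthetic check: per-sample gradients ~ N(g, Sigma) with known g and diagonal Sigma.
rng = np.random.default_rng(3)
g = np.array([0.3, -0.7, 1.2, 0.1])
sigma2 = np.array([0.5, 1.0, 2.0, 0.25])
grads = rng.normal(loc=g, scale=np.sqrt(sigma2), size=(20_000, 4))

print(estimate_B_simple(grads), sigma2.sum() / (g @ g))        # estimate vs. ground truth
```

Note that if only microbatch gradients of size $b$ are available (the realistic case for LLMs), their covariance is $\boldsymbol{\Sigma}/b$, so the same recipe returns roughly $\mathcal{B}_{\text{simple}}/b$ and has to be rescaled accordingly.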

Data Efficiency#

Beyond the direct relationship between learning rate and batch size, I believe the derived relationship between training data volume and the number of training steps is an essential and insightful result in its own right. In particular, this conclusion appears to be more general than the learning rate relationship (4): we will later see that SignSGD yields a conclusion of the same form, even though its learning rate does not follow equation (4).

The original paper's discussion of this part is rather involved; the derivation below is my simplified version. Specifically, substituting $\eta^*$ back into the expected loss (2) (with $\tilde{\boldsymbol{\varphi}}_B = \tilde{\boldsymbol{g}}_B$) yields:

(8) \[ \overline{\Delta\mathcal{L}} = \mathcal{L}(\boldsymbol{w}) - \mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta^*\tilde{\boldsymbol{g}}_B)] \approx \frac{\Delta\mathcal{L}_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} \]

where $\Delta\mathcal{L}_{\max} = \frac{(\boldsymbol{g}^{\top}\boldsymbol{g})^2}{2\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}$. How should we interpret this result? First, it is a monotonically increasing function of $B$. When $B\to\infty$, it equals $\Delta\mathcal{L}_{\max}$. In other words, if we could use an infinitely large batch size, the loss reduction per step would be $\Delta\mathcal{L}_{\max}$, requiring the minimum number of training steps, denoted as $S_{\min}$.
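
For completeness, the substitution behind (8) is just evaluating the quadratic in (2) at its minimum: writing $A = \boldsymbol{g}^{\top}\boldsymbol{g}$ and $C = \boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \tr(\boldsymbol{\Sigma}\boldsymbol{H})/B$ for the SGD case, so that $\eta^* = A/C$, we have

\[ \overline{\Delta\mathcal{L}} = \eta^* A - \frac{1}{2}(\eta^*)^2 C = \frac{A^2}{C} - \frac{A^2}{2C} = \frac{A^2}{2C} = \frac{(\boldsymbol{g}^{\top}\boldsymbol{g})^2}{2\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}\cdot\frac{1}{1 + \mathcal{B}_{\text{noise}}/B} \]

which is exactly (8).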

If the batch size is finite, the average loss reduction per step is only $\overline{\Delta\mathcal{L}}$. This means that, on average, we need $1 + \mathcal{B}_{\text{noise}}/B$ steps to achieve the reduction that would be achieved in 1 step with infinite batch size. Therefore, to reach the same loss level, we need to train for $S = (1 + \mathcal{B}_{\text{noise}}/B)S_{\min}$ steps.

Since the batch size is $B$, we can easily derive that the total training data consumed is $E = BS = (B + \mathcal{B}_{\text{noise}})S_{\min}$. This result shows that after increasing the batch size, to achieve the same effect, we also need to appropriately increase the data volume $E$; when $B\to 0$, the required data volume is minimized, at $E_{\min} = \mathcal{B}_{\text{noise}}S_{\min}$. Using these notations, we can write:

(9) \[ \left(\frac{S}{S_{\min}} - 1\right)\left(\frac{E}{E_{\min}} - 1\right) = 1 \]

This is the classical relationship between training data volume and training steps, containing two parameters $S_{\min}$ and $E_{\min}$. We can also experimentally search for multiple $(S,E)$ pairs to fit the above equation, thereby estimating $S_{\min}$ and $E_{\min}$, and subsequently estimating $\mathcal{B}_{\text{noise}} = E_{\min} / S_{\min}$. For more analytical details, please refer back to the previous article "How Should Learning Rate Scale with Batch Size?" or OpenAI's original paper "An Empirical Model of Large-Batch Training".
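
Here is a minimal sketch of such a fit, using synthetic $(S, E)$ pairs generated from the relation itself (the 1% noise, the chosen batch sizes, and the ground-truth values are all arbitrary). The only non-obvious step is that (9) rearranges to the straight line $1/S = 1/S_{\min} - \mathcal{B}_{\text{noise}}/E$, so an ordinary linear fit suffices:

```python
import numpy as np

rng = np.random.default_rng(4)

# Ground truth used only to generate synthetic (S, E) "measurements".
S_min_true, B_noise_true = 1_000.0, 4_096.0

B = np.array([256, 512, 1024, 2048, 4096, 8192, 16384], dtype=float)
S = (1 + B_noise_true / B) * S_min_true * rng.normal(1.0, 0.01, size=B.size)  # steps to target loss
E = B * S                                                                     # data consumed

# Equation (9) rearranges to the straight line 1/S = 1/S_min - B_noise * (1/E),
# so S_min, B_noise (and hence E_min) drop out of an ordinary linear fit.
slope, intercept = np.polyfit(1.0 / E, 1.0 / S, deg=1)
S_min_hat, B_noise_hat = 1.0 / intercept, -slope
E_min_hat = B_noise_hat * S_min_hat

print(S_min_hat, B_noise_hat, E_min_hat)   # should be close to 1000, 4096, 4.096e6
```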

Analysis Challenges#

The discussion so far has stayed within the realm of SGD, which is computationally trivial. The real complexity arises when $\tilde{\boldsymbol{\varphi}}_B$ depends nonlinearly on $\tilde{\boldsymbol{g}}_B$. For instance, SignSGD corresponds to $\newcommand{\sign}{\mathop{\text{sign}}}\tilde{\boldsymbol{\varphi}}_B=\sign(\tilde{\boldsymbol{g}}_B)$, which is often used as a proxy for Adam in theoretical analyses; a more accurate approximation is SoftSignSGD, which incorporates $\epsilon$ and which we attempted to analyze in "How Does Adam's Epsilon Affect Learning Rate Scaling Laws?".

In these nonlinear scenarios, computing $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$ and $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$ is often quite challenging, even if we assume $\tilde{\boldsymbol{g}}_B$ follows a simple normal distribution (note: in the SGD analysis, we did not need to assume any specific distributional form). For example, in the previous article, for SignSGD with $\tilde{\boldsymbol{\varphi}}_B=\sign(\tilde{\boldsymbol{g}}_B)$, we went through the following steps to compute $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$:

1. Assume the components of $\tilde{\boldsymbol{g}}_B$ are independent, reducing the problem to the expectation of a single component $\tilde{\varphi}_B=\sign(\tilde{g}_B)$ (not bold);

2. Assume $\tilde{g}_B$ (now a scalar) follows a normal distribution, then compute $\mathbb{E}[\tilde{\varphi}_B]$, with the answer expressed using the $\newcommand{\erf}{\mathop{\text{erf}}}\erf$ function;

3. Approximate the $\erf$ function using a function of the form $x/\sqrt{x^2+c}$ to simplify the result.

That is, we had to go through numerous convoluted steps just to obtain an approximate result amenable to further analysis (this process first appeared in Tencent's paper "Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling"). And this is already the relatively simple case, because SoftSignSGD is even more involved:

1. Assume the components of $\tilde{\boldsymbol{g}}_B$ are independent, reducing the problem to the expectation of a single component $\tilde{\varphi}_B=\newcommand{\softsign}{\mathop{\text{softsign}}}\softsign(\tilde{g}_B, \epsilon)$;

2. Approximate the $\softsign$ function using a piecewise linear function to enable integration;

3. Assume $\tilde{g}_B$ follows a normal distribution, combine with the approximation from step 2, and compute $\mathbb{E}[\tilde{\varphi}_B]$, resulting in a complex function containing $\erf$;

4. Approximate the complex function using a function of the form $x/\sqrt{x^2+c}$ to simplify the result.

But that's not all. After expending such great effort and making so many assumptions, we have barely managed to compute $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$. Next, we still need to compute $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$, which is often even more complex (SignSGD is an exception because $\sign(x)^2$ is always 1, making it simpler). However, computational complexity is secondary; the main issue is that these steps appear to have no generalizable pattern, seemingly requiring case-by-case analysis, which is mentally exhausting.
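
To make the SignSGD case concrete, here is a small numerical sketch of steps 2 and 3 from the list above, with a made-up scalar mean and variance for $\tilde{g}_B$. The constant $c = \pi/4$ in the $x/\sqrt{x^2+c}$ approximation is simply the choice that matches the slope of $\erf$ at the origin and its limits at $\pm\infty$; it is not necessarily the constant used in the referenced papers:

```python
import numpy as np
from scipy.special import erf

rng = np.random.default_rng(5)

mu, sigma, B = 0.3, 1.5, 16                 # made-up scalar gradient mean, per-sample std, batch size
sigma_B = sigma / np.sqrt(B)                # std of the batch gradient g_B

# Step 2: under the Gaussian assumption, E[sign(g_B)] has a closed form in erf.
monte_carlo = np.sign(rng.normal(mu, sigma_B, size=1_000_000)).mean()
exact = erf(mu / (sigma_B * np.sqrt(2)))

# Step 3: erf(x) ~ x / sqrt(x^2 + c); c = pi/4 matches erf's slope at 0 and its limits at +-infinity.
x = mu / (sigma_B * np.sqrt(2))
approx = x / np.sqrt(x**2 + np.pi / 4)

print(monte_carlo, exact, approx)
```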

To Be Continued#

To avoid making this article excessively long, we will pause here for now, having briefly reviewed the existing analytical results and computational difficulties. In the next article, I will introduce some attempts I have made to reduce the mental burden in the derivation process.

Citation Information#

Original Article: Su Jianlin. Rethinking Learning Rate and Batch Size (Part 1): Current Landscape. Scientific Spaces.

How to cite this translation:

Su, J. Rethinking Learning Rate and Batch Size (Part 1): Current Landscape [Translated by Juanxi Tian]. Scientific Spaces.

BibTeX:

@article{su2025rethinking_lr_bs_part1,
  title   = {Rethinking Learning Rate and Batch Size (Part 1): Current Landscape},
  author  = {Su, Jianlin},
  journal = {Scientific Spaces},
  year    = {2025},
  url     = {https://kexue.fm/archives/11260},
  note    = {Translated by Juanxi Tian (ScalingOpt Team)}
}